package scanner

import (
	"context"
	"testing"
	"time"

	malysisv1pb "buf.build/gen/go/safedep/api/protocolbuffers/go/safedep/messages/malysis/v1"
	packagev1 "buf.build/gen/go/safedep/api/protocolbuffers/go/safedep/messages/package/v1"
	malysisv1 "buf.build/gen/go/safedep/api/protocolbuffers/go/safedep/services/malysis/v1"
	"github.com/safedep/dry/adapters"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"google.golang.org/grpc"

	"github.com/safedep/vet/pkg/models"
)

// mockMalwareAnalysisServiceClient is a testify-backed stand-in for the
// Malysis gRPC client. Each method records its arguments via mock.Mock and
// returns whatever the test configured with On(...).Return(...).
//
// Responses are extracted with a comma-ok type assertion so a test may
// configure Return(nil, err) with an untyped nil. A plain assertion on a nil
// interface value would panic instead of surfacing the configured error.
type mockMalwareAnalysisServiceClient struct {
	mock.Mock
}

// QueryPackageAnalysis records the call and returns the configured response and error.
func (m *mockMalwareAnalysisServiceClient) QueryPackageAnalysis(
	ctx context.Context,
	in *malysisv1.QueryPackageAnalysisRequest,
	opts ...grpc.CallOption,
) (*malysisv1.QueryPackageAnalysisResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.QueryPackageAnalysisResponse)
	return resp, args.Error(1)
}

// AnalyzePackage records the call and returns the configured response and error.
func (m *mockMalwareAnalysisServiceClient) AnalyzePackage(
	ctx context.Context,
	in *malysisv1.AnalyzePackageRequest,
	opts ...grpc.CallOption,
) (*malysisv1.AnalyzePackageResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.AnalyzePackageResponse)
	return resp, args.Error(1)
}

// GetAnalysisReport records the call and returns the configured response and error.
func (m *mockMalwareAnalysisServiceClient) GetAnalysisReport(
	ctx context.Context,
	in *malysisv1.GetAnalysisReportRequest,
	opts ...grpc.CallOption,
) (*malysisv1.GetAnalysisReportResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.GetAnalysisReportResponse)
	return resp, args.Error(1)
}

// InternalAnalyzePackage records the call and returns the configured response and error.
func (m *mockMalwareAnalysisServiceClient) InternalAnalyzePackage(
	ctx context.Context,
	in *malysisv1.InternalAnalyzePackageRequest,
	opts ...grpc.CallOption,
) (*malysisv1.InternalAnalyzePackageResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.InternalAnalyzePackageResponse)
	return resp, args.Error(1)
}

// ListPackageAnalysisRecords records the call and returns the configured response and error.
func (m *mockMalwareAnalysisServiceClient) ListPackageAnalysisRecords(
	ctx context.Context,
	in *malysisv1.ListPackageAnalysisRecordsRequest,
	opts ...grpc.CallOption,
) (*malysisv1.ListPackageAnalysisRecordsResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.ListPackageAnalysisRecordsResponse)
	return resp, args.Error(1)
}

// InternalAgenticAnalyzePackage records the call and returns the configured response and error.
func (m *mockMalwareAnalysisServiceClient) InternalAgenticAnalyzePackage(
	ctx context.Context,
	in *malysisv1.InternalAgenticAnalyzePackageRequest,
	opts ...grpc.CallOption,
) (*malysisv1.InternalAgenticAnalyzePackageResponse, error) {
	args := m.Called(ctx, in, opts)
	resp, _ := args.Get(0).(*malysisv1.InternalAgenticAnalyzePackageResponse)
	return resp, args.Error(1)
}

// TestMalysisMalwareAnalysisQueryEnricherEnrich drives Enrich against a mocked
// Malysis client. For each case it expects exactly one QueryPackageAnalysis
// call, built from the package's ecosystem, name and version. It then checks
// either that the configured error is surfaced, or that the mocked response
// fields are copied onto the package's malware analysis result.
func TestMalysisMalwareAnalysisQueryEnricherEnrich(t *testing.T) {
	testCases := []struct {
		name          string
		pkg           *models.Package
		mockResponse  *malysisv1.QueryPackageAnalysisResponse
		mockError     error
		expectedError bool
	}{
		{
			name: "successful enrichment for npm package",
			pkg: &models.Package{
				PackageDetails: models.NewPackageDetail("npm", "test-package", "1.0.0"),
				Manifest: &models.PackageManifest{
					Ecosystem: models.EcosystemNpm,
				},
			},
			mockResponse: &malysisv1.QueryPackageAnalysisResponse{
				AnalysisId:         "test-analysis-id",
				Report:             &malysisv1pb.Report{},
				VerificationRecord: &malysisv1pb.VerificationRecord{},
			},
			mockError:     nil,
			expectedError: false,
		},
		{
			name: "gRPC call returns error",
			pkg: &models.Package{
				PackageDetails: models.NewPackageDetail("maven", "test-package", "1.0.0"),
				Manifest: &models.PackageManifest{
					Ecosystem: models.EcosystemMaven,
				},
			},
			mockResponse:  nil,
			mockError:     assert.AnError,
			expectedError: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mockClient := &mockMalwareAnalysisServiceClient{}
			mockGHA := &adapters.GithubClient{}

			// The enricher is expected to send exactly this request; the mock
			// matches it by value, so any field drift fails AssertExpectations.
			req := &malysisv1.QueryPackageAnalysisRequest{
				Target: &malysisv1pb.PackageAnalysisTarget{
					PackageVersion: &packagev1.PackageVersion{
						Package: &packagev1.Package{
							Ecosystem: tc.pkg.GetControlTowerSpecEcosystem(),
							Name:      tc.pkg.GetName(),
						},
						Version: tc.pkg.GetVersion(),
					},
				},
			}

			mockClient.On("QueryPackageAnalysis", mock.Anything, req, mock.Anything).
				Return(tc.mockResponse, tc.mockError)

			enricher := &malysisMalwareAnalysisQueryEnricher{
				client: mockClient,
				config: MalysisMalwareEnricherConfig{
					GrpcOperationTimeout: 2 * time.Second,
				},
				gha: mockGHA,
			}

			err := enricher.Enrich(tc.pkg, func(_ *models.Package) error {
				return nil
			})

			if tc.expectedError {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)

				// assert.NotNil does not stop the test; guard the field checks so
				// a missing result is reported as a failure rather than a panic.
				result := tc.pkg.GetMalwareAnalysisResult()
				if assert.NotNil(t, result) {
					assert.Equal(t, tc.mockResponse.GetAnalysisId(), result.AnalysisId)
					assert.Equal(t, tc.mockResponse.GetReport(), result.Report)
					assert.Equal(t, tc.mockResponse.GetVerificationRecord(), result.VerificationRecord)
				}
			}

			mockClient.AssertExpectations(t)
		})
	}
}

// TestNewMalysisMalwareAnalysisQueryEnricher checks the constructor's input
// validation. A nil gRPC connection or nil GitHub client must yield an error
// and a nil enricher. Valid inputs must yield an enricher that retains the
// given connection, GitHub client and config.
func TestNewMalysisMalwareAnalysisQueryEnricher(t *testing.T) {
	testCases := []struct {
		name          string
		cc            *grpc.ClientConn
		gha           *adapters.GithubClient
		config        MalysisMalwareEnricherConfig
		expectedError bool
	}{
		{
			name:          "nil client connection",
			cc:            nil,
			gha:           &adapters.GithubClient{},
			config:        MalysisMalwareEnricherConfig{},
			expectedError: true,
		},
		{
			name:          "nil github client",
			cc:            &grpc.ClientConn{},
			gha:           nil,
			config:        MalysisMalwareEnricherConfig{},
			expectedError: true,
		},
		{
			name:          "valid inputs",
			cc:            &grpc.ClientConn{},
			gha:           &adapters.GithubClient{},
			config:        MalysisMalwareEnricherConfig{},
			expectedError: false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			enricher, err := NewMalysisMalwareAnalysisQueryEnricher(tc.cc, tc.gha, tc.config)

			if tc.expectedError {
				assert.Error(t, err)
				assert.Nil(t, enricher)
				return
			}

			// These assertions are non-fatal; stop before dereferencing the
			// enricher so a constructor failure is reported, not a nil panic.
			if !assert.NoError(t, err) || !assert.NotNil(t, enricher) {
				return
			}

			assert.Equal(t, tc.cc, enricher.cc)
			assert.Equal(t, tc.gha, enricher.gha)
			assert.Equal(t, tc.config, enricher.config)
		})
	}
}
